package com.ezlcp.gateway.filter;

import com.alibaba.nacos.common.utils.StringUtils;
import com.ezlcp.commons.utils.RequestUtil;
import com.ezlcp.gateway.utils.XssCleanRuleUtils;
import io.netty.buffer.ByteBufAllocator;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Global gateway filter that strips cross-site scripting (XSS) payloads from incoming requests.<br />
 * <ul>
 *   <li>GET / DELETE: the raw query string is passed through {@code XssCleanRuleUtils.xssClean}.</li>
 *   <li>POST / PUT whose Content-Type is {@code application/x-www-form-urlencoded} or
 *       {@code application/json} (with any parameters, e.g. {@code charset}): the body is cleaned and
 *       forwarded with a recomputed {@code Content-Length}.</li>
 *   <li>Everything else (e.g. multipart file uploads) and every path matched by
 *       {@code ezlcp.ignore.xss.urls} is forwarded unchanged.</li>
 * </ul>
 *
 * @author Elwin ZHANG
 * @date 2023/2/1 14:31
 */
@Component
@Slf4j
public class XssRequestGlobalFilter implements GlobalFilter, Ordered {
    /**
     * Wraps byte arrays without copying (unpooled heap buffers), so nothing has to be released
     * if the rewritten body is never consumed.
     */
    private static final NettyDataBufferFactory BUFFER_FACTORY = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT);

    /** Paths exempt from XSS cleaning, taken from the comma-separated property {@code ezlcp.ignore.xss.urls}. */
    @Value("#{'${ezlcp.ignore.xss.urls}'.trim().split(',')}")
    private List<String> ignoreUrls;

    /**
     * Cleans the query string (GET/DELETE) or the form/JSON body (POST/PUT) of the request and passes
     * the possibly rewritten exchange down the chain.
     * <p>
     * GET handling follows Spring Cloud Gateway's built-in filter:
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @return completion signal of the remaining chain
     * @see org.springframework.cloud.gateway.filter.factory.AddRequestParameterGatewayFilterFactory
     * <p>
     * POST handling follows Spring Cloud Gateway's built-in filter:
     * @see org.springframework.cloud.gateway.filter.factory.rewrite.ModifyRequestBodyGatewayFilterFactory
     */
    @Override
    // Kept from the original: presumably covers checked exceptions declared by project helpers
    // such as RequestUtil.containsUrl — TODO confirm whether it is still needed.
    @SneakyThrows
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        log.debug("----自定义防XSS攻击网关全局过滤器生效----");
        ServerHttpRequest req = exchange.getRequest();
        HttpMethod method = req.getMethod();

        // Skip URLs configured on the XSS ignore list.
        String reqUrl = req.getPath().value();
        if (RequestUtil.containsUrl(reqUrl, ignoreUrls)) {
            return chain.filter(exchange);
        }

        if (HttpMethod.GET.equals(method) || HttpMethod.DELETE.equals(method)) {
            return cleanQuery(exchange, chain);
        }

        // Only form/JSON bodies are rewritten; e.g. multipart uploads are left untouched.
        MediaType mediaType = parseMediaType(req.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE));
        boolean cleanBody = (HttpMethod.POST.equals(method) || HttpMethod.PUT.equals(method))
                && isCleanableMediaType(mediaType);
        if (cleanBody) {
            return cleanBody(exchange, chain, charsetOf(mediaType));
        }
        return chain.filter(exchange);
    }

    /** Cleans the raw query string and forwards a request carrying the cleaned URI. */
    private Mono<Void> cleanQuery(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest req = exchange.getRequest();
        URI uri = req.getURI();
        String rawQuery = uri.getRawQuery();
        if (StringUtils.isBlank(rawQuery)) {
            return chain.filter(exchange);
        }
        // NOTE(review): rawQuery is still percent-encoded (e.g. %3Cscript%3E); whether xssClean sees
        // through that encoding depends on XssCleanRuleUtils — confirm.
        String cleanedQuery = XssCleanRuleUtils.xssClean(rawQuery);
        URI newUri;
        try {
            // build(true): the query is already encoded, so it is validated rather than re-encoded.
            newUri = UriComponentsBuilder.fromUri(uri)
                    .replaceQuery(cleanedQuery)
                    .build(true)
                    .toUri();
        } catch (Exception e) {
            log.error("get请求清理xss攻击异常", e);
            throw new IllegalStateException("Invalid URI query: \"" + cleanedQuery + "\"", e);
        }
        ServerHttpRequest request = req.mutate().uri(newUri).build();
        return chain.filter(exchange.mutate().request(request).build());
    }

    /**
     * Reads the whole body, cleans it and forwards a request whose body and Content-Length reflect the
     * cleaned text. A request body can only be subscribed once, so the request is wrapped in a decorator
     * that replays the rewritten bytes.
     */
    private Mono<Void> cleanBody(ServerWebExchange exchange, GatewayFilterChain chain, Charset charset) {
        ServerHttpRequest req = exchange.getRequest();
        if (req.getHeaders().getContentLength() == 0) {
            return chain.filter(exchange);
        }
        return DataBufferUtils.join(req.getBody())
                .map(dataBuffer -> {
                    try {
                        byte[] bytes = new byte[dataBuffer.readableByteCount()];
                        dataBuffer.read(bytes);
                        return new String(bytes, charset);
                    } finally {
                        // The joined buffer may be pooled; release it once its bytes are copied out.
                        DataBufferUtils.release(dataBuffer);
                    }
                })
                // An empty body (e.g. chunked with no data) makes join() complete empty; without this
                // default the chain would never be invoked and the request would be dropped.
                .defaultIfEmpty("")
                .flatMap(body -> {
                    String cleaned = StringUtils.isBlank(body) ? body : XssCleanRuleUtils.xssClean(body);
                    byte[] cleanedBytes = cleaned.getBytes(charset);
                    ServerHttpRequest request = withBody(req, cleanedBytes);
                    return chain.filter(exchange.mutate().request(request).build());
                });
    }

    /**
     * Decorates {@code req} so that it carries {@code body}. The Content-Length header is set to the
     * byte length (not the character length). Any Transfer-Encoding header is dropped because the
     * length is now known exactly, and sending both headers is invalid.
     */
    private ServerHttpRequest withBody(ServerHttpRequest req, byte[] body) {
        return new ServerHttpRequestDecorator(req) {
            @Override
            public HttpHeaders getHeaders() {
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(super.getHeaders());
                httpHeaders.remove(HttpHeaders.TRANSFER_ENCODING);
                httpHeaders.setContentLength(body.length);
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                // Wrap lazily on each subscription, so a retry re-reads the bytes from the start.
                return Flux.defer(() -> Flux.just(toDataBuffer(body)));
            }
        };
    }

    /**
     * Returns whether the type/subtype of {@code mediaType} is form-urlencoded or JSON. Parameters
     * such as {@code charset} are ignored, and matching is case-insensitive.
     */
    private static boolean isCleanableMediaType(MediaType mediaType) {
        return mediaType != null
                && (MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(mediaType)
                || MediaType.APPLICATION_JSON.equalsTypeAndSubtype(mediaType));
    }

    /** Parses a Content-Type header value; returns {@code null} if it is blank or malformed. */
    private static MediaType parseMediaType(String contentType) {
        if (StringUtils.isBlank(contentType)) {
            return null;
        }
        try {
            return MediaType.parseMediaType(contentType);
        } catch (IllegalArgumentException e) {
            // InvalidMediaTypeException: treat as "not a cleanable content type".
            return null;
        }
    }

    /** Charset declared by {@code mediaType}, or UTF-8 when none is declared. */
    private static Charset charsetOf(MediaType mediaType) {
        Charset charset = mediaType == null ? null : mediaType.getCharset();
        return charset != null ? charset : StandardCharsets.UTF_8;
    }

    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE + 999;
    }

    /**
     * Wraps raw bytes in a {@link DataBuffer} without copying.
     *
     * @param bytes body bytes
     * @return a buffer exposing {@code bytes}
     */
    private static DataBuffer toDataBuffer(byte[] bytes) {
        return BUFFER_FACTORY.wrap(bytes);
    }

}